{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "5f28a88b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8505fa405ed14e359ed3a68ef656a446",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading model.safetensors:   0%|          | 0.00/2.27G [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 21.7 s, sys: 5.67 s, total: 27.4 s\n",
      "Wall time: 1min 19s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "import os\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\"\n",
    "\n",
    "from dataclasses import dataclass, field\n",
    "from typing import Optional\n",
    "import contextlib\n",
    "\n",
    "import torch\n",
    "from datasets import load_dataset\n",
    "from peft import LoraConfig\n",
    "from transformers import (\n",
    "    AutoModelForCausalLM,\n",
    "    AutoTokenizer,\n",
    "    BitsAndBytesConfig,\n",
    "    HfArgumentParser,\n",
    "    AutoTokenizer,\n",
    "    TrainingArguments,\n",
    ")\n",
    "\n",
    "\n",
    "from trl import SFTTrainer\n",
    "\n",
    "from peft import (\n",
    "    prepare_model_for_kbit_training,\n",
    "    LoraConfig,\n",
    "    get_peft_model,\n",
    "    PeftModel\n",
    ")\n",
    "\n",
    "# bnb_config = BitsAndBytesConfig(\n",
    "#     load_in_4bit=True,\n",
    "#     bnb_4bit_quant_type=\"nf4\",\n",
    "#     bnb_4bit_compute_dtype=\"bfloat16\",\n",
    "#     bnb_4bit_use_double_quant=False,\n",
    "# )\n",
    "\n",
    "# device_map = {\"\": 0}\n",
    "base_model = \"bigcode/starcoderbase-1b\"\n",
    "model = \"smangrul/full-finetune-starcoderbase-3b-deepspeed-colab\"\n",
    "tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)\n",
    "model = AutoModelForCausalLM.from_pretrained(\n",
    "    model, quantization_config=None, \n",
    "    device_map=None, \n",
    "    trust_remote_code=True, \n",
    "    torch_dtype=torch.bfloat16,\n",
    ")\n",
    "\n",
    "if not hasattr(model, \"hf_device_map\"):\n",
    "    model.cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "7b8fdd99",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_code_completion(prefix, suffix):\n",
    "    text = prompt = f\"\"\"<fim_prefix>{prefix}<fim_suffix>{suffix}<fim_middle>\"\"\"\n",
    "    model.eval()\n",
    "    outputs = model.generate(input_ids=tokenizer(text, return_tensors=\"pt\").input_ids.cuda(), \n",
    "                             max_new_tokens=512,\n",
    "                             temperature=0.3,\n",
    "                             top_k=50,\n",
    "                             top_p=0.95,\n",
    "                             do_sample=True,\n",
    "                             repetition_penalty=1.2,\n",
    "                             #stopping_criteria=stopping_criteria\n",
    "                            )\n",
    "    return tokenizer.batch_decode(outputs, skip_special_tokens=False)[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "36512a6e",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
      "Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<fim_prefix>\n",
      "# Here is the correct implementation of the two sum code exercise\n",
      "# time complexity: O(N)\n",
      "# space complexity: O(N)\n",
      "def two_sum(arr, target_sum):\n",
      "<fim_suffix>\n",
      "two_sum([3,4,5], 7)\n",
      "<fim_middle>    \"\"\"\n",
      "    Two Sum Problem\n",
      "\n",
      "    Args:\n",
      "        arr (list[int]):\n",
      "            List containing integers.\n",
      "        target_sum (int):\n",
      "            The desired sum for which to find the two smallest elements in an array.\n",
      "\n",
      "    Returns:\n",
      "        Tuple[List[int]]: A tuple with the indices of the found values.\n",
      "    \"\"\"\n",
      "    # Initialize our variables\n",
      "    low = 0\n",
      "    high = len(arr) - 1\n",
      "\n",
      "    while low < high:\n",
      "        mid = (low + high) // 2\n",
      "\n",
      "        if arr[mid] == target_sum:\n",
      "            return [mid, mid + 1]\n",
      "        elif arr[mid] > target_sum:\n",
      "            high = mid - 1\n",
      "        else:\n",
      "            low = mid + 1\n",
      "\n",
      "    return [low, high + 1]<|endoftext|>\n"
     ]
    }
   ],
   "source": [
    "prefix = \"\"\"\n",
    "# Here is the correct implementation of the two sum code exercise\n",
    "# time complexity: O(N)\n",
    "# space complexity: O(N)\n",
    "def two_sum(arr, target_sum):\n",
    "\"\"\"\n",
    "\n",
    "suffix = \"\"\"\n",
    "two_sum([3,4,5], 7)\n",
    "\"\"\"\n",
    "print(get_code_completion(prefix, suffix))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "3737d667",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
      "Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<fim_prefix>from peft import LoraConfig, TaskType\n",
      "\n",
      "peft_config = LoraConfig(<fim_suffix>)<fim_middle>\n",
      "    r=8, lora_alpha=32, target_modules=[\"q\", \"v\"], lora_dropout=0.15, bias=\"none\"\n",
      "<|endoftext|>\n"
     ]
    }
   ],
   "source": [
    "prefix = \"\"\"from peft import LoraConfig, TaskType\n",
    "\n",
    "peft_config = LoraConfig(\"\"\"\n",
    "\n",
    "suffix = \")\"\n",
    "print(get_code_completion(prefix, suffix))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "cbf83544",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
      "Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<fim_prefix>from accelerate import Accelerator\n",
      "\n",
      "accelerator = Accelerator()\n",
      "\n",
      "model, optimizer, training_dataloader, scheduler = <fim_suffix><fim_middle>accelerator.prepare(\n",
      "    model, optimizer, training_dataloader, scheduler\n",
      ")\n",
      "<|endoftext|>\n"
     ]
    }
   ],
   "source": [
    "prefix = \"\"\"from accelerate import Accelerator\n",
    "\n",
    "accelerator = Accelerator()\n",
    "\n",
    "model, optimizer, training_dataloader, scheduler = \"\"\"\n",
    "\n",
    "suffix = \"\"\"\"\"\"\n",
    "print(get_code_completion(prefix, suffix))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bdce5f84",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
